# coding=utf-8
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Copyright 2019 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implements Svhn data class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow_datasets as tfds

from . import base as base
from .registry import Registry
# This constant specifies the percentage of data that is used to create custom
# train/val splits. Specifically, TRAIN_SPLIT_PERCENT% of the official training
# split is used as a new training split and the rest is used for validation.
TRAIN_SPLIT_PERCENT = 90


@Registry.register("data.svhn", "class")
class SvhnData(base.ImageTfdsData):
  """Provides SVHN data.

  The Street View House Numbers (SVHN) Dataset is an image digit recognition
  dataset of over 600,000 color digit images coming from real world data.
  Split size:
    - Training set: 73,257 images
    - Testing set: 26,032 images
    - Extra training set: 531,131 images
  Following the common setup on SVHN, we only use the official training and
  testing data. Images are cropped to 32x32.

  URL: http://ufldl.stanford.edu/housenumbers/
  """

  def __init__(self, data_dir=None):
    dataset_builder = tfds.builder("svhn_cropped:3.*.*", data_dir=data_dir)
    dataset_builder.download_and_prepare()

    # Example counts are retrieved from the tensorflow dataset info.
    trainval_count = dataset_builder.info.splits[tfds.Split.TRAIN].num_examples
    test_count = dataset_builder.info.splits[tfds.Split.TEST].num_examples

    origin = True
    if origin:
        # Creates a dict with example counts for each split.
        num_samples_splits = {
            # Calculates the train/val split example count based on percent.
            "train": TRAIN_SPLIT_PERCENT * trainval_count // 100,
            "val": trainval_count - TRAIN_SPLIT_PERCENT * trainval_count // 100,
            "trainval": trainval_count,
            "test": test_count,
            "train800": 800,
            "val200": 200,
            "train800val200": 1000,
        }
        # num_samples_splits:  {'train': 65931, 'val': 7326, 'trainval': 73257, 'test': 26032, 'train800': 800, 'val200': 200, 'train800val200': 1000}
        # print("num_samples_splits: ", num_samples_splits)

        # Defines dataset specific train/val/trainval/test splits.
        # The validation set is split out of the original training set, and the
        # remaining examples are used as the "train" split. The "trainval" split
        # corresponds to the original training set.
        tfds_splits = {
            "train":
                "train[:{}]".format(num_samples_splits["train"]),
            "val":
                "train[{}:]".format(num_samples_splits["train"]),
            "trainval":
                "train",
            "test":
                "test",
            "train800":
                "train[:800]",
            "val200":
                "train[{}:{}]".format(num_samples_splits["train"],
                                    num_samples_splits["train"] + 200),
            "train800val200":
                "train[:800]+train[{}:{}]".format(
                    num_samples_splits["train"], num_samples_splits["train"] + 200),
        }
        # tfds_splits:  {'train': 'train[:65931]', 'val': 'train[65931:]', 'trainval': 'train', 'test': 'test', 'train800': 'train[:800]', 'val200': 'train[65931:66131]', 'train800val200': 'train[:800]+train[65931:66131]'}
        # print("tfds_splits: ", tfds_splits)
    
    else:
        # Creates a dict with example counts for each split.
        num_samples_splits = {
            # Calculates the train/val split example count based on percent.
            "train": TRAIN_SPLIT_PERCENT * trainval_count // 100,
            "val": trainval_count - TRAIN_SPLIT_PERCENT * trainval_count // 100,
            "trainval": trainval_count,
            "test": test_count,
            "train800": 10000,
            "val200": 2000,
            "train800val200": 12000,
        }

        # Defines dataset specific train/val/trainval/test splits.
        # The validation set is split out of the original training set, and the
        # remaining examples are used as the "train" split. The "trainval" split
        # corresponds to the original training set.
        tfds_splits = {
            "train":
                "train[:{}]".format(num_samples_splits["train"]),
            "val":
                "train[{}:]".format(num_samples_splits["train"]),
            "trainval":
                "train",
            "test":
                "test",
            "train800":
                "train[:10000]",
            "val200":
                "train[{}:{}]".format(num_samples_splits["train"],
                                    num_samples_splits["train"] + 2000),
            "train800val200":
                "train[:10000]+train[{}:{}]".format(
                    num_samples_splits["train"], num_samples_splits["train"] + 2000),
        }
        # tfds_splits:  {'train': 'train[:65931]', 'val': 'train[65931:]', 'trainval': 'train', 'test': 'test', 'train800': 'train[:800]', 'val200': 'train[65931:66131]', 'train800val200': 'train[:800]+train[65931:66131]'}
        # print("tfds_splits: ", tfds_splits)

    super(SvhnData, self).__init__(
        dataset_builder=dataset_builder,
        tfds_splits=tfds_splits,
        num_samples_splits=num_samples_splits,
        num_preprocessing_threads=400,
        shuffle_buffer_size=10000,
        # Note: Rename tensors but keep their original types.
        base_preprocess_fn=base.make_get_and_cast_tensors_fn({
            "image": ("image", None),
            "label": ("label", None),
        }),
        num_classes=dataset_builder.info.features["label"]
        .num_classes)
